Oh man, this is a lot more impressive than my brief skim of the paper made it out to be! I especially like this graph.
This is a casual thought and by no means something I've thought hard about - I'm curious whether b is a lagging indicator, which is to say, there's actually more magic going on in the weights and once weights go through this change, b catches up to it.
Another speculative thought: let's say we are moving from 4* -> 5* and |W_3| is the new W that is taking on high magnitude. Does this occur because W_3 somehow has enough individual weights to jointly look at its two (new) neighbors' W_i's roughly equally?
Does the cos similarity and/or dot product of this new W_3 with its neighbors grow during the 4* -> 5* transition (and does this occur prior to the change in b?)
The changes in the matrix W and the bias b happen at the same time; b is not a lagging indicator.
Question about the gif - to me it looks like the phase transition is more like:
4++- to unstable 5+- to 4+- to 5-
(Unstable 5+- seems to have similar loss to 4+-).
Why do we not count the large red bar as a "-" ?
Good question. What counts as a "-" is spelled out in the paper, but it's only outlined here heuristically. The "5 like" thing it seems to go near on the way down is not actually a critical point.
Would you be willing to share the raw data from the plot "Developmental Stages of TMS"? I'm specifically hoping to look at line plots of weights vs biases over time.
Thanks.
Hello @RGRGRG, yes, I can share the raw data for that plot. If you can direct message me your email address or any other way for communicating a JSON file, I can send them to you.
Also, the developmental stages of this model (see the Setup section) are quite robust. If you wish to reproduce this trajectory, starting from the 4++- initial configuration (see the "TMS critical points are k-gons" section) will likely produce a similar trajectory.
Where do the bewildering and intricate structures of Nature come from? What purpose do they serve? In his famous 1917 book "On Growth and Form" the Scottish biologist and mathematician D'Arcy Wentworth Thompson wrote the following about the geometric forms of Phaeodaria, shown above:
Chris Olah, who pioneered mechanistic interpretability for neural networks and is somewhat of an intellectual descendant of D'Arcy Wentworth Thompson, has often cited biology as an inspiration for his work. Indeed, in many respects trained neural networks are more similar to biological organisms than to traditional computer programs. One of the most striking parallels is that, just as in biology, structure often forms over neural network training in what appear to be developmental stages.
In 2022, Elhage et al. introduced the Toy Model of Superposition (TMS) to study when and how models represent more features than they have dimensions, aiming to discern a method to systematically enumerate all features in a model to improve interpretability. Here, we show one case of this model: a neural network trying to reconstruct six-dimensional sparse inputs by encoding them non-linearly in two dimensions. A visualization of the encoding strategy (weights of the neural network) is shown in the top left, and it is visible that over training the network passes through three "forms" or stages of development (labelled 4++−, 4+, 5 for reasons that will be explained below), separated by sudden transitions where the structure changes. From here on we will refer to this animation as the Developmental Stages of TMS.
What are these stages of development and what kind of process are these transitions? Can we understand why the training process passes through these particular stages in this particular order? Does this toy model reveal a path towards a better understanding of the bewildering and intricate structure of larger neural networks?
In Chen et al. (2023), it is shown that the stages of development in TMS can be concretely understood from the perspective of singular learning theory. These stages (or forms) correspond to critical points of the loss landscape L(w), and the development between stages (growth) corresponds to phase transitions between these critical points, where the only dynamically permissible transitions involve decreasing loss L(w) and increasing model complexity λ.
Let us elaborate.
On Growth and Form
We are familiar with the idea that a biological system might, over the course of its development from an embryo, take various forms that are stable and distinct. Like many familiar notions, the concept of a form is difficult to define in a careful way:
Obviously.
For systems like neural networks whose development is governed by the process of gradient descent ˙w=−∇L(w) for a potential L(w), a reasonable mathematical formalisation of the informal notion of a form is the concept of a critical point. A network parameter w∗∈W is a critical point if all partial derivatives of the loss vanish, that is, ∇L(w∗)=0. It's a place where gradient descent slows down, or even stops.
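As a minimal illustration of gradient descent slowing down near one critical point before settling at another, the sketch below runs discrete gradient descent on a hypothetical two-parameter potential. This is not the TMS potential: the function, step size and starting point are all assumptions chosen purely for illustration.

```python
import math

def potential(x, y):
    # Hypothetical toy potential: saddle at (0, 0), minima at (+1, 0) and (-1, 0).
    return 0.25 * (x * x - 1.0) ** 2 + 0.5 * y * y

def gradient(x, y):
    # Partial derivatives of the potential with respect to x and y.
    return x * (x * x - 1.0), y

def gradient_descent(x, y, lr=0.1, steps=400):
    """Return the list of (x, y, loss, gradient norm) visited by gradient descent."""
    history = []
    for _ in range(steps):
        gx, gy = gradient(x, y)
        history.append((x, y, potential(x, y), math.hypot(gx, gy)))
        x, y = x - lr * gx, y - lr * gy
    gx, gy = gradient(x, y)
    history.append((x, y, potential(x, y), math.hypot(gx, gy)))
    return history

# Start very close to the line x = 0, so the trajectory passes near the saddle.
history = gradient_descent(x=1e-3, y=1.0)
final_x, final_y, final_loss, final_grad_norm = history[-1]
```

In this toy run the trajectory first drifts towards the saddle at the origin, where the gradient norm is small and the loss sits near 0.25, and only later escapes to the minimum at (1, 0). Inspecting the loss column of `history` shows a plateau followed by a drop.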
Local minima and maxima in two dimensions are critical points familiar from high school calculus, as are saddles in three dimensions (with both increasing and decreasing directions nearby). We use three different words (maxima, minima, saddles) to distinguish these three kinds of local geometry.
For potentials L(w) like the loss function of a neural network, there are many more kinds of critical points than these, each with their own particular local geometry. It is this geometry, and the configuration of these critical points relative to one another, which dominates the large scale behaviour of dynamical systems (Gilmore, in his text on Catastrophe Theory, calls this the Crowbar Principle).
To return then to the biological setting, we may identify forms with critical points and growth with flows between neighbourhoods of critical points.
In the Toy Model of Superposition, the forms are regular polygons, and we observe that growth is restricted to occur in a specific way: loss must go down, and complexity must go up. What we see in the specific 4++−→4+→5 transition of the animation is actually a more general phenomenon, at least in the setting of TMS: the complexity of the polygon increases in each transition. This observation connects the dynamical transitions of neural network training to the phase transitions of singular learning theory (SLT). Let's dive in.
The Zoo of TMS Critical Points
In the high-sparsity limit of the TMS potential L(w), it is possible to explicitly calculate and study critical points. The problem of classifying these critical points is similar in some respects to problems in the theory of tight frames, compressed sensing, and the Thomson problem. In the case of r=2 feature dimensions, these critical points have a clear, interpretable meaning: they correspond to k-gons.
The Setup
We consider a two-layer ReLU autoencoder with input and output dimension c, with r≤c hidden dimensions in the network:
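$$f : X \times W \to \mathbb{R}^c, \qquad f(x, w) = \mathrm{ReLU}(W^T W x + b),$$
where $w = (W, b) \in W$. Here $W = [W_1, \dots, W_c]$ is a matrix of $c$ column vectors $W_i \in \mathbb{R}^r$, and $b \in \mathbb{R}^c$ is a bias.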
In the high-sparsity limit, an input sample $x \in X = [0,1]^c$ from the true data-generating distribution $q(x)$ has the form $x = \mu e_i$, where $e_i$ is the $i$th basis vector with $i$ drawn uniformly from $\{1, \dots, c\}$, and $\mu$ is sampled uniformly from $[0,1]$.
The objective of the TMS learning machine is to find an efficient method to compress the high-dimensional true input distribution q(x) into a lower-dimensional representation, in other words, to approximately reconstruct any input, f(x,w)≈x, using fewer feature dimensions r than input dimensions c. So for any dataset of samples Dn={x1,…,xn}, the empirical TMS loss to minimise during training is
$$L_n(w) = \frac{1}{n} \sum_{i=1}^{n} \lVert x_i - f(x_i, w) \rVert^2.$$
We refer to $L(w) = \mathbb{E}[L_n(w)]$ as the TMS potential, for which a closed form expression is given in Lemma 3.1 of Chen et al. (2023).
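The setup above can be sketched in a few lines of NumPy. This is an illustrative sketch rather than the authors' code: the function names, the random initialisation scale and the sample size are assumptions, and no training loop is included.

```python
import numpy as np

def tms_forward(W, b, X):
    """f(x, w) = ReLU(W^T W x + b), applied to each row of X.

    W has shape (r, c), b has shape (c,), X has shape (n, c).
    """
    gram = W.T @ W                       # (c, c) and symmetric
    return np.maximum(X @ gram + b, 0.0)

def sample_high_sparsity(n, c, rng):
    """Draw n inputs x = mu * e_i, with i uniform over coordinates and mu uniform on [0, 1]."""
    X = np.zeros((n, c))
    idx = rng.integers(0, c, size=n)
    X[np.arange(n), idx] = rng.uniform(0.0, 1.0, size=n)
    return X

def empirical_loss(W, b, X):
    """Mean squared reconstruction error over the rows of X."""
    residual = X - tms_forward(W, b, X)
    return np.mean(np.sum(residual ** 2, axis=1))

rng = np.random.default_rng(0)
r, c, n = 2, 6, 1000                     # the r = 2, c = 6 case discussed in the text
X = sample_high_sparsity(n, c, rng)
W = rng.normal(scale=0.5, size=(r, c))   # arbitrary initialisation (assumption)
b = np.zeros(c)
loss_value = empirical_loss(W, b, X)
```

With `W` and `b` both zero the network outputs zero everywhere, so the loss reduces to the mean squared norm of the inputs, which is a convenient sanity check.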
TMS Critical Points are k-gons
When we set the number of feature dimensions to r=2, meaning the autoencoder has c weight vectors Wi∈R2 in the plane to perform its compression, the most interesting low-loss critical points[1] are regular polygons: triangles, squares, pentagons and the like. Each critical point w∗=(W∗,b∗) is characterised by three quantities:
We denote each critical point as a $k^{\sigma+,\phi-}$-gon, which we will call a k-gon for short.[2] These are the fundamental forms of the Toy Model of Superposition.
In the r=2,c=6 case, we empirically catalogued and theoretically proved the existence of 18 low-loss critical points.[3] Since we have a closed form expression for the TMS potential L(w), we can easily plug in each identified critical point to find its loss and create a diagram akin to the energy-levels of different states in physics.
SGD Plateaus at k-gons
The Developmental Stages of TMS animation shows SGD settling at three different plateaus through training, which we visually saw corresponded to different "forms" of the parameter w=(W,b). These forms are different k-gon critical points, each with their own loss plateau.
This is not a one-off.
No matter where SGD is initialised, it always plateaus on a known k-gon critical point. Sometimes it even transitions through multiple plateaus, like we see in the animation.
So, in TMS, SGD training trajectories can be thought of as a developmental journey through different k-gon forms. The rapid drops in loss from plateau to plateau signify instances of actual growth. For example, the model literally develops a new limb when it transitions from a 4+-gon to a 5-gon.
Opposing Staircases: Loss ↓, Complexity ↑
Let's take a closer look at the green line in the Developmental Stages of TMS animation, which we have claimed is measuring the "local complexity" ^λ(t) of the model at time t. (Why we care about the local complexity ^λ will be explained below).
Notice how as loss goes down over training, the local complexity goes up, as if they're opposing staircases.
In fact, if we look back at the colourful SGD Plateaus plot above, there seems to be an inverse relationship between the energy L(w∗α) of a critical point w∗α and its estimated complexity ^λα, suggesting a more general phenomenon than just the 4++−→4+→5 transition we see here.
One can plot (^λα,L(w∗α)) for different k-gon critical points w∗α, and then connect the pairs between which there were observed phase transitions in SGD. In doing this we repeatedly observe this same pattern: structure seems to form in particular ways, consistently decreasing loss while increasing complexity. Some transitions are dynamically permissible, and others aren't.
This single plot of loss against complexity captures the growth and form of TMS.
What explains these highly restricted developmental trajectories? Training often skips past critical points with intermediary energy levels, so it can't simply be a case of moving further and further down the loss landscape. Something else is going on.
Singular learning theory (SLT) can shed some light.
TMS meets SLT
Bayesian Phases are also Critical Points
Phases Minimise the Free Energy
Those already familiar with the general story we have been telling will recognise that there is a better name for these "forms" that we observe throughout TMS training. They are phases. Moreover, the sharp drops between plateaus are indicative of growth in structure, and these are phase transitions.
Plateaus in loss are sometimes referred to informally as phases in the machine learning literature. To get to the heart of this, we turn to the Bayesian perspective on learning. Here the posterior $p(w \mid D_n) \propto e^{-nL_n(w)} \varphi(w)$ (for some prior $\varphi(w)$ on parameter space $W \subset \mathbb{R}^d$) is the central object, containing all information about our system for any sample size $n$.
In statistical physics and Bayesian learning theory a phase is some region Wα of state/parameter space W that has concentrated posterior mass — in other words, a configuration of the system that the posterior deems likely to occur. This posterior mass is measured (for tractability reasons) using the free energy Fn(Wα) [4]. The Bayesian model selection process is then guided by the principle of free energy minimisation (equivalently, posterior mass maximisation).
So are Bayesian phases also related to critical points of L(w)? They sure are.
Geometric Signatures of Bayesian Phases
Let Wα be a region of parameter space W⊂Rd that contains a critical point w∗α, where w∗α locally minimises the loss L(w)=E[Ln(w)] in Wα. Then as the number of samples n tends to infinity, Watanabe's free energy formula tells us that the local free energy of Wα is asymptotically given by
$$F_n(W_\alpha) \approx n L_n(w^*_\alpha) + \lambda_\alpha \log n + c_\alpha \quad \text{as } n \to \infty.$$
Here $nL_n(w^*_\alpha)$ is the loss (aka potential energy), $\lambda_\alpha \in \mathbb{Q}_{>0}$ is the local learning coefficient (aka model complexity, aka the RLCT), and $c_\alpha$ includes a term of order $\log\log n$ and a constant-order term which incorporates information such as the prior volume (or equivalently in this setting, the weight norm $\lVert w \rVert^2$). [5]
This formula tells us that the posterior concentrates on neighbourhoods Wα of critical points, which we can now call phases in good conscience. Remember, singular models like neural networks induce loss landscapes L(w) that have a diverse set of critical points with differing tradeoffs between loss and complexity, unlike regular models where the complexity λ is constant across parameter space, $\lambda = d/2$.
That these local geometries vary across parameter space is what gives singular models their rich phase structure. This is exactly what we saw with the diverse catalogue of TMS k-gon phases, each having different loss-complexity tradeoffs (^λα,L(w∗α)).
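To make the role of λ concrete, here is a hedged one-dimensional sketch that is not part of the TMS analysis. It numerically evaluates the local free energy for two hypothetical potentials, a regular one (L(w) = w²) and a singular one (L(w) = w⁴), using a uniform prior on [−1, 1] and treating the loss as independent of the sample (both assumptions). Since the minimum loss is zero in each case, the slope of the free energy against log n estimates λ.

```python
import numpy as np

def local_free_energy(loss_fn, n, half_width=1.0, grid_points=400_001):
    """F_n = -log of the integral of exp(-n L(w)) * prior(w) over [-half_width, half_width].

    Uses a uniform prior and a plain Riemann sum on a fine grid.
    """
    w = np.linspace(-half_width, half_width, grid_points)
    dw = w[1] - w[0]
    prior = 1.0 / (2.0 * half_width)
    integral = np.sum(np.exp(-n * loss_fn(w)) * prior) * dw
    return -np.log(integral)

def estimated_coefficient(loss_fn, n_small=1e3, n_large=1e5):
    """Slope of F_n against log n, which approximates lambda when min L = 0."""
    f_small = local_free_energy(loss_fn, n_small)
    f_large = local_free_energy(loss_fn, n_large)
    return (f_large - f_small) / (np.log(n_large) - np.log(n_small))

lambda_regular = estimated_coefficient(lambda w: w ** 2)    # close to 1/2 = d/2
lambda_singular = estimated_coefficient(lambda w: w ** 4)   # close to 1/4
```

The regular potential recovers λ = d/2 with d = 1, while the flatter quartic minimum has a smaller coefficient: more posterior volume sits near the minimum, so it is penalised less as n grows.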
What does SLT say about phase transitions, though?
Bayesian Phase Transitions
Internal Model Selection
The different tradeoffs between loss and complexity across phases result in what is effectively internal model selection. The posterior prefers high-loss-low-complexity models Wα in the face of insufficient evidence (low n) — a kind of Occam's Razor — but, as you increase your evidence with more training samples n, the posterior becomes more certain that low-loss-high-complexity models Wβ are better.
This process forms the basis of Bayesian phase transitions, which occur when there is a sudden change in which region of parameter space has the dominant posterior mass (equivalently, lower free energy).
Suppose we carve up (or more precisely, coarse grain) our parameter space W⊂Rd into a finite set of disjoint phases {Wα}α covering W. (Precisely how you coarse grain parameter space to perform maximally interesting inference is a non-trivial process - in TMS this is natural, in other settings it may not be).
By the log-sum-exp approximation, the global free energy $F_n$, defined by $e^{-F_n} = \sum_\alpha e^{-F_n(W_\alpha)}$, is approximately
$$F_n = -\log \sum_\alpha e^{-F_n(W_\alpha)} \approx \min_\alpha F_n(W_\alpha) \approx \min_\alpha \left[ n L_n(w^*_\alpha) + \lambda_\alpha \log n + c_\alpha \right],$$
meaning that it is dominated by the phase $W_\alpha$ with the lowest free energy for a given value of $n$, while the rest are exponentially suppressed. This is the sense in which the posterior "chooses" the phase $W_\alpha$ at a given $n$.
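A small numerical sketch shows how tight the log-sum-exp approximation is once the phases are separated by a few units of free energy. The local free energies below are hypothetical numbers, not values from the TMS experiments.

```python
import math

def global_free_energy(local_free_energies):
    """F_n = -log(sum_alpha exp(-F_n(W_alpha))), computed stably by factoring out the minimum."""
    f_min = min(local_free_energies)
    total = sum(math.exp(-(f - f_min)) for f in local_free_energies)
    return f_min - math.log(total)

# Hypothetical local free energies for three phases at some fixed n.
local = [35.2, 47.6, 60.1]
f_global = global_free_energy(local)

# The gap between the minimum and the exact value always lies in [0, log(number of phases)].
gap = min(local) - f_global
```

Here the gap is tiny because the second-best phase is suppressed by a factor of roughly exp(−12.4). Near a phase transition, where two local free energies are close, the gap approaches log 2 and the approximation is at its worst.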
Phase Transitions from High-Loss-Low-Complexity to Low-Loss-High-Complexity
So when does the dominant phase change from the point of view of the free energy? In the simplest case, say we had two phases α and β, defined by neighbourhoods Wα and Wβ centred at critical points w∗α and w∗β respectively. Let's suppose that β has better loss but higher complexity, so
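$$\Delta L = L_n(w^*_\beta) - L_n(w^*_\alpha) < 0 \quad \text{and} \quad \Delta\lambda = \lambda_\beta - \lambda_\alpha > 0$$
(and suppose that $c_\alpha - c_\beta = 0$ for ease). Then as $n$ increases, a Bayesian phase transition from $\alpha \to \beta$ will occur at $n \approx n_{\mathrm{crit}}$, when the ordering of their free energies flips. Solving $F_n(W_\alpha) = F_n(W_\beta)$ shows that $n_{\mathrm{crit}}$ is the unique solution to [6]
$$n \Delta L + \Delta\lambda \log n = 0.$$
We say there is a Bayesian phase transition of Type A at $n = n_{\mathrm{crit}}$.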
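The crossover equation has no elementary closed form, but it is easy to solve numerically. The sketch below uses bisection, with hypothetical values of ΔL and Δλ that are not taken from the TMS experiments. One caveat, flagged as an observation about the equation itself: nΔL + Δλ log n is negative at n = 1 and rises to a peak at n = Δλ/|ΔL| before falling, so it can also vanish at a small n where the asymptotic formula is not meaningful. The sketch therefore searches only beyond the peak.

```python
import math

def free_energy_gap(n, delta_loss, delta_lambda):
    """F_n(W_beta) - F_n(W_alpha) = n * dL + d_lambda * log(n), taking c_alpha = c_beta."""
    return n * delta_loss + delta_lambda * math.log(n)

def critical_sample_size(delta_loss, delta_lambda, tol=1e-9):
    """Largest n at which the gap vanishes, or None if the gap is never positive."""
    if not (delta_loss < 0 and delta_lambda > 0):
        raise ValueError("expects delta_loss < 0 and delta_lambda > 0")
    n_peak = -delta_lambda / delta_loss           # the gap is maximal here
    if free_energy_gap(n_peak, delta_loss, delta_lambda) <= 0:
        return None                               # beta is preferred at every n
    lo, hi = n_peak, 2.0 * n_peak
    while free_energy_gap(hi, delta_loss, delta_lambda) > 0:
        hi *= 2.0                                 # expand until the gap turns negative
    while hi - lo > tol * hi:                     # bisect on the sign change
        mid = 0.5 * (lo + hi)
        if free_energy_gap(mid, delta_loss, delta_lambda) > 0:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)

# Hypothetical values, not taken from the TMS experiments.
n_crit = critical_sample_size(delta_loss=-0.01, delta_lambda=1.0)
```

For these illustrative values the crossover lands at a few hundred samples. Below it the lower-complexity phase α has the smaller free energy; above it the lower-loss phase β takes over.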
Are visions of opposing staircases flashing before your eyes?
Perhaps Dynamical Transitions have Bayesian Antecedents
A Bayesian transition occurs as we increase sample size n, whereas dynamical transitions occur in SGD time t. We know that both are governed by the critical points of L(w). But is there a deeper connection going on?
The Toy Model of Superposition suggests there is. In both settings, transitions occur from high-loss-low-complexity k-gons to low-loss-high-complexity k-gons, suggesting that any dynamical transition has a Bayesian phase transition "standing behind it", which we will call a Bayesian antecedent.
This observation leads us to put forward the following hypothesis:
When operationalised in the TMS setting for a transition α→β in time, and assuming that only Type A transitions occur, the BAH implies:
This was not falsified in the TMS experiments.
Do Bayesian Phase Transitions even exist? Yes!
"Wait a second, didn't you say the free energy formula is an asymptotic approximation? Why should I believe it for finite n, let alone moderately sized n? Do these Bayesian phase transitions even exist, empirically?"
Yes, yes they do.
Returning to our catalogue of k-gon critical points, we are actually able to theoretically calculate the learning coefficient of the 5- and 6-gons[7]:
Running the numbers, the free energy formula theoretically predicts that there will be a phase transition at ncrit=601, which we can verify empirically by using MCMC to sample from the TMS posterior distribution at a range of dataset sizes n.
Crucially, because we know the critical points in TMS are k-gons, we are able to very naturally coarse-grain parameter space into phases {Wα}α, thereby assigning any posterior sample w=(W,b) into a phase. This feature of TMS allows us to calculate the relative frequency of each phase for different n, i.e. the probability mass
$$P(w \in W_\alpha \mid D_n) = \frac{e^{-F_n(W_\alpha)}}{\sum_\beta e^{-F_n(W_\beta)}}$$
for each phase $W_\alpha$. Plotting these relative frequencies over a range of $n$ gives us phase occupancy curves, allowing us to compare the theoretical occupancy curve (determined by the theoretical $L_n(w^*_\alpha)$, $\lambda_\alpha$, $c_\alpha$) with the empirically derived version from MCMC experiments.
The most salient feature of the curves is the Bayesian phase transition from the 5-gon to the 6-gon, which happens at the crossover around ncrit≈600 and is seen in both the theoretical and empirical curves (occurring at remarkably similar values of n). Since the 6-gon is the global minimum of the TMS potential L(w) for r=2,c=6, it is not surprising to see that it is the eventual dominant phase for sufficiently large n. But just as SLT predicts, since the 5-gon has lower complexity, it can be the dominant phase for some range of n, which we see here for n∈[100,600].
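A theoretical occupancy curve of this kind can be sketched directly from the leading-order free energy formula. The (loss, λ) pairs below are hypothetical placeholders and are not the values for the TMS 5-gon and 6-gon; the phase names and the choice of n values are likewise assumptions.

```python
import math

def local_free_energy(n, loss, lam, const=0.0):
    """Leading-order free energy n * L + lambda * log(n) + c for one phase."""
    return n * loss + lam * math.log(n) + const

def phase_occupancy(n, phases):
    """Posterior mass of each phase: a softmax of the negative local free energies."""
    energies = {name: local_free_energy(n, *params) for name, params in phases.items()}
    f_min = min(energies.values())                # subtract the minimum for stability
    weights = {name: math.exp(-(f - f_min)) for name, f in energies.items()}
    total = sum(weights.values())
    return {name: w / total for name, w in weights.items()}

# Hypothetical (loss, lambda) pairs; these are NOT the TMS 5-gon and 6-gon values.
phases = {"low_complexity": (0.02, 1.0), "low_loss": (0.01, 2.0)}
occupancy_curve = {n: phase_occupancy(n, phases) for n in (200, 650, 2000)}
```

With these placeholder numbers the low-complexity phase holds most of the mass at n = 200, the two phases are roughly balanced near n = 650, and the low-loss phase dominates by n = 2000, which is the qualitative shape of a crossover in an occupancy plot.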
It's worth noting that for low n, a lot of this analysis does start to break down. This is primarily because the lower order terms of the asymptotic free energy formula become increasingly important, and classifying each sample into a phase becomes less precise due to a more diffuse posterior. [8]
To guard against some of these issues, one can manually inspect MCMC samples using a 2D t-SNE projection to verify that the phase classifications do cluster together appropriately.
Updates and where to now?
The work presented here represents the first step toward not just verifying some of the key tenets of the developmental interpretability agenda but also advancing toward a science of deep learning through the thermodynamic lens of singular learning theory.
Some updates we took away from working on this:
It's worth pausing here for a moment to consider the parts of this work that are specific to TMS ("mechanistic" claims, specific to the problem at hand) and those that we expect to hold more broadly ("thermodynamic" claims, true across general learning machines).
Here are some examples:
Obviously we don't expect to see k-gons as critical points in general networks. These structures/forms/phases will manifest in many different ways. But we do expect them to form in phase transitions as a universal phenomenon.
Some of the guiding questions from here include:
We believe that understanding the growth and form of neural networks is fundamental to a science of deep learning, and that this understanding will come from a combination of deep mathematics and careful study of many particular systems, like TMS. In this we take inspiration from D'Arcy Wentworth Thompson, who had a passion not only for mathematics, but for collating many beautiful examples of organism development from across the natural world.
Here we are slightly abusing the term "critical point" - in general, we are actually referring to critical submanifolds. In the case of the TMS potential, there are two generic symmetries that mean the different "critical points" are actually critical submanifolds:
Permutation symmetry, where, like any other feedforward ReLU neural network we can permute hidden nodes (i.e. simultaneously permute weight columns and bias entries) and still yield the same function, and;
Rotation symmetry (or more precisely, O(r)-invariance, where O(2) consists of the rotations and reflections of the plane), since for any orthogonal matrix $O \in \mathbb{R}^{r \times r}$ we have
$$f(x, (OW, b)) = \mathrm{ReLU}(W^T O^T O W x + b) = \mathrm{ReLU}(W^T W x + b) = f(x, (W, b)).$$
These generic symmetries mean that singular learning theory is an essential tool to fully understand the TMS system.
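The O(r)-invariance is easy to confirm numerically. The following sketch draws arbitrary parameters and a random orthogonal matrix (obtained from a QR decomposition, which is just one convenient construction) and checks that the network output is unchanged. All names and sizes here are illustrative assumptions.

```python
import numpy as np

def tms_forward(W, b, x):
    """f(x, (W, b)) = ReLU(W^T W x + b) for a single input vector x."""
    return np.maximum(W.T @ W @ x + b, 0.0)

rng = np.random.default_rng(0)
r, c = 2, 6
W = rng.normal(size=(r, c))
b = rng.normal(size=c)
x = rng.uniform(size=c)

# Build a random orthogonal matrix O from the QR decomposition of a Gaussian matrix.
O, _ = np.linalg.qr(rng.normal(size=(r, r)))

# Replacing W by O W leaves W^T W, and hence the output, unchanged.
outputs_match = np.allclose(tms_forward(O @ W, b, x), tms_forward(W, b, x))
```

The check works for reflections as well as rotations, since only the identity OᵀO = I is used.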
There are some proven constraints on what combinations of (k,σ,ϕ) are possible - for more details, see Appendix A of Chen et al. (2023).
While we have not proven that these are all of the possible critical points (and indeed, there are likely many more high-loss critical points that SGD never cares about), we are confident that we found all of the important ones.
For a given region of parameter space Wα⊂W the free energy is
$$F_n(W_\alpha) = -\log\left( \int_{W_\alpha} e^{-nL_n(w)} \varphi(w) \, dw \right),$$
thus measuring the posterior concentration of $W_\alpha$.
To be clear, this is the free energy to leading order. It is an active problem to determine lower order terms of the free energy, investigating which other geometric quantities are important for model selection.
The precision of this ncrit value is not to be taken too literally - it is, after all, based on an asymptotic approximation. Interpreting such asymptotics requires great care for any finite-n system (which is all we ever care about in real systems).
In fact, we have a reasonably good idea of the learning coefficients for various 4σ+,ϕ−-gons too; however, these remain unpublished due to some technical non-analyticity complications.
Such analysis also depends on the coarse graining of the phases, and while we are quite sure we have located all low-energy critical points, it is still possible that the presence of other unaccounted for phases could introduce inaccuracies in these occupancy curves.